#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from pprint import pformat
from typing import Dict, Optional
import collections
import itertools
import logging
import pathlib
import random
import shutil
import time

import psutil
import pynvml
import torch
import torch.cuda
import torch.distributed
import torch.nn.functional
import torch.nn
import torch.optim
import torch.utils.tensorboard

import fairdiplomacy.selfplay.metrics
import fairdiplomacy.selfplay.pg.data_loader
import fairdiplomacy.selfplay.pg.vtrace
import fairdiplomacy.selfplay.search.data_loader
import fairdiplomacy.selfplay.remote_metric_logger
from fairdiplomacy.models.consts import POWERS, MAX_SEQ_LEN
from fairdiplomacy.models.base_strategy_model.load_model import (
    load_base_strategy_model_model_and_args,
)
from fairdiplomacy.models.state_space import EOS_IDX
from fairdiplomacy.agents.base_strategy_model_wrapper import forward_model_with_output_transform
from fairdiplomacy.utils.timing_ctx import TimingCtx
from fairdiplomacy.selfplay.trainer_state import NetTrainingState, TrainerState
from fairdiplomacy.selfplay.execution_context import ExecutionContext
from fairdiplomacy.selfplay.pg.rollout import order_logits_to_action_logprobs
from fairdiplomacy.selfplay.paths import get_rendezvous_path, get_torch_ddp_init_fname
from fairdiplomacy.selfplay.search.rollout import ReSearchRolloutBatch
from fairdiplomacy.selfplay.search.search_utils import (
    compute_search_policy_entropy,
    compute_search_policy_cross_entropy_sampled,
    evs_to_policy,
)
from fairdiplomacy.utils.exception_handling_process import ExceptionHandlingProcess
from fairdiplomacy.utils.multiprocessing_spawn_context import get_multiprocessing_ctx
import conf.conf_cfgs
import heyhi
import nest

# Multiprocessing context for this module, obtained from the project helper.
mp = get_multiprocessing_ctx()

# Exported checkpoints that agents can load. The main (policy or policy+value)
# net and the optional separate value net each get their own directory.
CKPT_MAIN_DIR, CKPT_VALUE_DIR = pathlib.Path("ckpt"), pathlib.Path("ckpt_value")
# Exported checkpoint filename; formatted with the epoch number.
CKPT_TPL = "epoch%06d.ckpt"

# Full trainer-state snapshot, rewritten every epoch and used to resume after a requeue.
REQUEUE_CKPT = pathlib.Path("requeue.ckpt")


def compute_entropy_loss(logits, mask):
    """Return the entropy loss, i.e. the masked negative entropy of ``softmax(logits)``.

    ``mask`` gets a trailing axis appended, so it must broadcast against
    ``logits.shape[:-1]``. The summed ``p * log p`` is divided by the number of
    unmasked entries. A small epsilon keeps an all-zero mask from dividing by zero.
    Minimizing this value maximizes the policy entropy.
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    weighted = probs * log_probs * mask.unsqueeze(-1)
    return weighted.sum() / (mask.sum() + 1e-5)


def compute_policy_gradient_loss(action_logprobs, advantages):
    """Return the policy-gradient loss ``-mean(logprob * advantage)``.

    The advantages are detached, so gradients flow only through ``action_logprobs``.
    """
    weighted = action_logprobs * advantages.detach()
    return weighted.mean().neg()


def compute_sampled_entropy_loss(action_logprobs):
    """Return a surrogate loss for negative entropy estimated from sampled log-probs.

    Each log-prob is weighted by the detached ``1 + logprob``. The value is the
    mean of ``logprob * (1 + logprob)``, and the gradient flows only through the
    first factor.
    """
    baseline = action_logprobs.detach() + 1
    return torch.mean(action_logprobs * baseline)


def sample_and_compute_sampled_entropy_loss(model, obs):
    """Sample actions from ``model`` on ``obs`` and return ``(entropy_loss, entropy)``.

    The leading two axes of every tensor in ``obs`` (``[T, B, ...]``) are merged
    before the forward pass. Only slots whose sampled orders contain at least one
    non-EOS index contribute. ``entropy`` is the negative mean of their sampled
    log-probs.
    """
    # Merge time and batch axes: [T, B, ...] -> [T * B, ...].
    merged_obs = nest.map(lambda tensor: tensor.flatten(end_dim=1), obs)
    order_idxs, sampled_logprobs, _ = forward_model_with_output_transform(
        model, merged_obs, temperature=1.0, need_value=False, pad_to_max=True
    )
    # order_idxs is presumably [T * B, POWERS, SEQ_LEN]; reduce over the last axis
    # to find slots that contain at least one real order.
    is_valid = (order_idxs != EOS_IDX).any(-1)
    valid_logprobs = sampled_logprobs.view(-1)[is_valid.view(-1)]
    entropy_loss = compute_sampled_entropy_loss(valid_logprobs)
    entropy_estimate = -valid_logprobs.mean()
    return entropy_loss, entropy_estimate


def to_onehot(indices, max_value):
    """Return a ``[len(indices), max_value]`` one-hot matrix for 1-D ``indices``.

    The result uses torch's default floating dtype and lives on ``indices.device``.
    """
    num_rows = len(indices)
    encoded = torch.zeros((num_rows, max_value), device=indices.device)
    return encoded.scatter_(1, indices.unsqueeze(1), 1)


def build_optimizer(net, optimizer_cfg):
    """Create an optimizer and an optional LR scheduler for ``net``.

    The optimizer type comes from the ``optimizer`` oneof of ``optimizer_cfg``
    ("adam" or "sgd"). Its sub-config is passed as keyword arguments to the
    torch optimizer. Adam with an explicit ``weight_decay`` is built as AdamW.

    When configured, the following schedulers are added in this order:
    - linear warmup (``warmup_epochs``)
    - cosine annealing (``cosine_decay_epochs``)
    - step decay (``step_decay_epochs`` / ``step_decay_factor``)
    - a custom warmup/decay schedule (``warmup_decay``)

    Several schedulers are combined with ``ChainedScheduler``.

    Returns:
        A tuple ``(optimizer, scheduler)``. ``scheduler`` is None when no
        schedule is configured.

    Raises:
        AssertionError: if no optimizer type is selected in ``optimizer_cfg``.
        KeyError: if the selected optimizer type is not "adam" or "sgd".
    """
    opt_name = optimizer_cfg.WhichOneof("optimizer")
    assert opt_name, f"Config must define an optimizer type: {optimizer_cfg}"
    opt_class_cfg = getattr(optimizer_cfg, opt_name)
    if opt_name == "adam" and optimizer_cfg.adam.weight_decay is not None:
        logging.info("Using AdamW optimizer")
        optimizer = torch.optim.AdamW(net.parameters(), **heyhi.conf_to_dict(opt_class_cfg))
    else:
        opt_class = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}[opt_name]
        optimizer = opt_class(net.parameters(), **heyhi.conf_to_dict(opt_class_cfg))

    scheduler_list = []
    if optimizer_cfg.warmup_epochs is not None and optimizer_cfg.warmup_epochs > 0:
        # Linear ramp from 1/warmup_epochs up to the full LR, then constant.
        scheduler_list.append(
            torch.optim.lr_scheduler.LambdaLR(
                optimizer, lambda epoch: min(1.0, (epoch + 1) / optimizer_cfg.warmup_epochs)
            )
        )
    if optimizer_cfg.cosine_decay_epochs:
        scheduler_list.append(
            torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, optimizer_cfg.cosine_decay_epochs
            )
        )
    if optimizer_cfg.step_decay_epochs:
        scheduler_list.append(
            torch.optim.lr_scheduler.StepLR(
                optimizer, optimizer_cfg.step_decay_epochs, gamma=optimizer_cfg.step_decay_factor
            )
        )

    if optimizer_cfg.warmup_decay:
        scheduler_list.append(
            torch.optim.lr_scheduler.LambdaLR(
                optimizer, _warmup_decay_schedule(optimizer_cfg.warmup_decay)
            )
        )

    if not scheduler_list:
        scheduler = None
    elif len(scheduler_list) == 1:
        scheduler = scheduler_list[0]
    else:
        # Only with pytorch 1.10. Hence the stanza above.
        # NOTE(review): LambdaLR recomputes lr as base_lr * f(epoch) on every step
        # (it is not a chainable/multiplicative scheduler). Inside ChainedScheduler
        # it may therefore override the effect of the other schedulers. Confirm
        # that the configured combinations behave as intended.
        scheduler = torch.optim.lr_scheduler.ChainedScheduler(scheduler_list)

    return optimizer, scheduler


def vtrace_from_logprobs_no_batch(**kwargs):
    """Run V-trace on unbatched inputs by temporarily adding a batch axis.

    Tensors with at least one dimension get a size-1 axis inserted at position 1.
    0-d tensors are promoted to shape ``[1]``. Every field of the V-trace result
    is squeezed along axis 1. The result is rebuilt positionally with the same
    type, presumably a namedtuple.
    """
    batched = {}
    for name, tensor in kwargs.items():
        batched[name] = tensor.unsqueeze(1) if tensor.shape else tensor.unsqueeze(0)
    result = fairdiplomacy.selfplay.pg.vtrace.from_importance_weights(**batched)
    squeezed_fields = [field.squeeze(1) for field in result]
    return type(result)(*squeezed_fields)


def clip_grad_norm_(parameters, max_norm, norm_type=2):
    """Clip the combined gradient norm of ``parameters`` in place.

    Adapted from PyTorch 1.5's ``torch.nn.utils.clip_grad_norm_`` as a faster
    grad-norm computation.

    Args:
        parameters: a tensor or an iterable of tensors whose ``.grad`` is clipped.
            Tensors without a gradient are ignored.
        max_norm: maximum allowed total gradient norm.
        norm_type: order of the norm, passed to ``torch.norm``.

    Returns:
        The total gradient norm before clipping, as a tensor. If no tensor has a
        gradient, returns ``torch.tensor(0.0)`` (matching PyTorch) instead of
        failing on an empty ``torch.stack``.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    params_with_grad = [p for p in parameters if p.grad is not None]
    if not params_with_grad:
        return torch.tensor(0.0)
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    per_param_norms = torch.stack(
        [torch.norm(p.grad.detach(), norm_type) for p in params_with_grad]
    )
    total_norm = torch.norm(per_param_norms, norm_type)
    # Epsilon avoids division by zero when all gradients are zero.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in params_with_grad:
            p.grad.detach().mul_(clip_coef)
    return total_norm


def load_or_create_base_strategy_model_model(model_path: str, *, reset_weights: bool, **kwargs):
    """Load a base strategy model and its args, optionally skipping the weights.

    ``reset_weights`` is a required keyword, so every caller states explicitly
    whether checkpoint weights are loaded. It is forwarded to the loader as
    ``skip_weight_loading``. All other keyword arguments pass through unchanged.

    Returns:
        A ``(model, args)`` tuple.
    """
    loaded = load_base_strategy_model_model_and_args(
        model_path, skip_weight_loading=reset_weights, **kwargs
    )
    model, model_args = loaded
    return model, model_args


class ExploitTrainer:
    def __init__(
        self, cfg: conf.conf_cfgs.ExploitTask, ectx: ExecutionContext, random_seed: int,
    ):
        """Validate the exploit config, fill in derived settings, and store it.

        Only lightweight setup happens here. Data loaders, models and loggers are
        created later by ``__call__``.

        Args:
            cfg: the exploit task config. An editable version is used to fill in
                derived ``search_rollout`` settings. The re-frozen result is
                stored as ``self.cfg``.
            ectx: execution context describing this process's role and DDP layout.
            random_seed: seed for ``self.random``, the trainer-local RNG.

        Raises:
            AssertionError: if the config combines unsupported options.
        """
        # This constructor is called by 3 different sets of processes. See ExecutionContext
        # On data generation workers (additional processes directly spawned by slurm), heyhi.is_master() is false.
        # These workers just stay permanently stuck in the DataLoader, generating data for the training master.
        self.ectx = ectx

        # A non-zero search_rollout.batch_size selects ReSearch training. There the
        # policy is trained iff search_policy_weight is non-zero, and the value iff
        # critic_weight > 0. Otherwise this is policy-gradient training of both.
        if cfg.search_rollout.batch_size:
            self.research = True
            self.is_policy_being_trained = bool(cfg.search_policy_weight)
            self.is_value_being_trained = cfg.critic_weight > 0
        else:
            self.research = False
            self.is_policy_being_trained = self.is_value_being_trained = True
        logging.info(f"self.research={self.research}")
        logging.info(f"self.is_policy_being_trained={self.is_policy_being_trained}")
        logging.info(f"self.is_value_being_trained={self.is_value_being_trained}")

        assert cfg.use_distributed_data_parallel == ectx.using_ddp
        if ectx.using_ddp:
            assert self.research, "DistributedDataParallel only supported for research"
            # Batch size and buffer capacity must split evenly across DDP ranks.
            assert cfg.search_rollout.batch_size % self.ectx.ddp_world_size == 0
            assert cfg.search_rollout.buffer.capacity % self.ectx.ddp_world_size == 0

            if self.ectx.is_training_master:
                logging.info(
                    f"This is a training master process, rank {self.ectx.training_ddp_rank}"
                )
            elif self.ectx.is_training_helper:
                logging.info(
                    f"This is a training helper process, rank {self.ectx.training_ddp_rank}"
                )
        else:
            assert self.ectx.ddp_world_size == 1

        cfg = cfg.to_editable()
        # A bare optimizer.lr with no optimizer type selected is applied as Adam's lr.
        if heyhi.conf_is_set(cfg, "optimizer.lr") and not cfg.optimizer.WhichOneof("optimizer"):
            heyhi.conf_set(cfg, "optimizer.adam.lr", heyhi.conf_get(cfg, "optimizer.lr"))
        if self.research:
            assert cfg.bootstrap_offline_targets, "Using online targets is not supported anymore"
            # Make rollouts collect the search outputs that the enabled losses consume.
            if cfg.search_rollout.extra_params.use_ev_targets:
                logging.info("Setting cfg.search_rollout.extra_params.collect_search_evs = true")
                cfg.search_rollout.extra_params.collect_search_evs = True
                logging.info(
                    "Setting cfg.search_rollout.extra_params.collect_search_policies = true"
                )
                cfg.search_rollout.extra_params.collect_search_policies = True
            if cfg.search_policy_weight:
                logging.info(
                    "Setting cfg.search_rollout.extra_params.collect_search_policies = %s",
                    self.is_policy_being_trained,
                )
                cfg.search_rollout.extra_params.collect_search_policies = (
                    self.is_policy_being_trained
                )
            if cfg.WhichOneof("maybe_search_ev_loss") is not None:
                logging.info("Setting cfg.search_rollout.extra_params.collect_search_evs = True")
                cfg.search_rollout.extra_params.collect_search_evs = True
            # Rollouts may only use the trained policy/value if it is being trained.
            assert (
                self.is_policy_being_trained
                or not cfg.search_rollout.extra_params.use_trained_policy
            )
            assert (
                self.is_value_being_trained
                or not cfg.search_rollout.extra_params.use_trained_value
            )

            # Every tagged h2h eval uses the trained policy/value exactly when that
            # part of the model is being trained. Manual overrides are rejected.
            for key in cfg.search_rollout.to_dict():
                h2h_eval_cfg = getattr(cfg.search_rollout, key)
                if key.startswith("h2h_eval") and h2h_eval_cfg.tag:
                    assert not h2h_eval_cfg.HasField(
                        "use_trained_policy"
                    ), f"Do not set use_trained_policy manually: {h2h_eval_cfg}"
                    h2h_eval_cfg.use_trained_policy = self.is_policy_being_trained
                    logging.info(
                        f"Setting cfg.search_rollout.{key}.use_trained_policy = {h2h_eval_cfg.use_trained_policy}"
                    )
                    assert not h2h_eval_cfg.HasField(
                        "use_trained_value"
                    ), f"Do not set use_trained_value manually: {h2h_eval_cfg}"
                    h2h_eval_cfg.use_trained_value = self.is_value_being_trained
                    logging.info(
                        f"Setting cfg.search_rollout.{key}.use_trained_value = {h2h_eval_cfg.use_trained_value}"
                    )
            if cfg.search_rollout.benchmark_only:
                # Benchmark mode measures data generation only: turn off evals,
                # warmup, buffer saving and train/generation ratio enforcement.
                logging.info("Benchmark only mode. Disabling h2h evals and sit checks")
                for key in cfg.search_rollout.to_dict():
                    h2h_eval_cfg = getattr(cfg.search_rollout, key)
                    if key.startswith("h2h_eval"):
                        h2h_eval_cfg.tag = ""
                cfg.search_rollout.test_situation_eval.do_eval = False
                assert cfg.num_train_gpus == 1, "Benchmark mode. Set num_train_gpus to 1"
                logging.info("Benchmark only mode. Disabling warmup")
                cfg.search_rollout.warmup_batches = 1
                logging.info("Benchmark only mode. Disabling buffer saving")
                cfg.search_rollout.buffer.ClearField("save_at")
                logging.info("Benchmark only mode. Setting overuse to 1")
                cfg.search_rollout.enforce_train_gen_ratio = 1.0

        cfg = cfg.to_frozen()
        # A separate value optimizer requires a separately trained value model.
        if cfg.value_optimizer is not None:
            assert self.research
            assert self.is_value_being_trained
            assert cfg.value_model_path
        # The cosine schedule must span at least the whole training run.
        if cfg.optimizer.cosine_decay_epochs is not None and cfg.trainer.max_epochs:
            assert cfg.optimizer.cosine_decay_epochs + 1 >= cfg.trainer.max_epochs
        if cfg.seed > 0:
            torch.manual_seed(cfg.seed)
        assert (
            cfg.search_policy_update_prob >= 1.0
            or cfg.search_rollout.extra_params.run_do_prob >= 1.0
        ), "Cannot mix search_policy_update_prob and search_rollout.extra_params.run_do_prob"

        self.cfg: conf.conf_cfgs.ExploitTask = cfg
        # Set later by _init_training_device_and_state.
        self.device: str
        self.state: TrainerState
        self.last_epoch_state: Dict

        assert self.research or not cfg.value_model_path, "Not supported in PG mode"

        # Trainer-local RNG (e.g. for subsampling which losses run on a step).
        self.random = random.Random(random_seed)

    # Initialize and prepare the data for training.
    # Data generation workers are expected to never return from this function.
    def _init_data_loader_workers_may_never_return(self):
        """Create ``self.data_loader`` for the current training mode.

        Policy-gradient mode uses the PG data loader with ``cfg.rollout``.
        ReSearch mode uses the search data loader with ``cfg.search_rollout``.
        Data generation processes block inside that constructor and never return.
        """
        if not self.research:
            self.data_loader = fairdiplomacy.selfplay.pg.data_loader.DataLoader(
                self.cfg.model_path, self.cfg.rollout
            )
            return

        # Data generation workers never return from this constructor.
        self.data_loader = fairdiplomacy.selfplay.search.data_loader.DataLoader(
            self.cfg.model_path,
            self.cfg.search_rollout,
            num_train_gpus=self.cfg.num_train_gpus,
            ectx=self.ectx,
        )
        assert (
            not self.ectx.is_rollouter
        ), "Data generation machines are expected to stay in DataLoader"

    # Initialize the state of training, taking into account checkpoints and requeing.
    def _init_training_device_and_state(self):
        """Choose the device, build or restore the trainer state, and sync workers.

        Steps, in order:
          1. Pick ``self.device``: CPU without CUDA, the per-rank GPU under DDP,
             otherwise "cuda".
          2. Load the policy net (or policy+value net), plus an optional separate
             value net, each with its optimizer and scheduler.
          3. Restore from the requeue checkpoint, an explicit checkpoint, or the
             latest exported checkpoints.
          4. Send the model to data generation workers. This is forced unless
             the state came from the local requeue checkpoint.
          5. Join the DDP process group when DDP is enabled.
        """
        self.device = self._select_training_device()
        self.state = self._build_initial_trainer_state()
        force_sync = self._restore_trainer_state_from_checkpoints()

        # Training device and model state should be correct now, so now make sure that
        # the model is sent to the data generation workers if it's not there yet.
        self.maybe_send_model_to_workers(force_sync=force_sync)

        if self.cfg.use_distributed_data_parallel:
            self._join_ddp_process_group()

    def _select_training_device(self) -> str:
        """Return the torch device string that training runs on."""
        if not torch.cuda.is_available():
            logging.warning("No CUDA found!")
            return "cpu"
        if self.cfg.use_distributed_data_parallel:
            # For now, ddp rank equals gpu index
            return f"cuda:{self.ectx.training_ddp_rank}"
        return "cuda"

    def _build_initial_trainer_state(self):
        """Load the configured net(s) and wrap them with fresh optimizers/schedulers."""
        shared_load_kwargs = dict(
            reset_weights=self.cfg.reset_agent_weights, map_location=self.device, eval=True,
        )
        if self.cfg.value_model_path is None:
            net, net_args = load_or_create_base_strategy_model_model(
                self.cfg.model_path, **shared_load_kwargs
            )
            logging.info("Policy+value model args:\n%s", pformat(net_args, indent=2))
            optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)
            value_state = None
        else:
            net, net_args = load_or_create_base_strategy_model_model(
                self.cfg.model_path,
                override_has_policy=True,
                override_has_value=False,
                **shared_load_kwargs,
            )
            logging.info("Policy model args:\n%s", pformat(net_args, indent=2))
            optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)

            value_net, value_net_args = load_or_create_base_strategy_model_model(
                self.cfg.value_model_path,
                override_has_policy=False,
                override_has_value=True,
                **shared_load_kwargs,
            )
            logging.info("Value model args:\n%s", pformat(value_net_args, indent=2))
            value_optim, value_lr_scheduler = build_optimizer(
                value_net, self.cfg.value_optimizer or self.cfg.optimizer
            )
            value_state = NetTrainingState(
                args=value_net_args,
                model=value_net,
                optimizer=value_optim,
                scheduler=value_lr_scheduler,
            )
        main_state = NetTrainingState(
            model=net, optimizer=optim, scheduler=lr_scheduler, args=net_args,
        )
        return TrainerState(net_state=main_state, value_net_state=value_state)

    def _restore_trainer_state_from_checkpoints(self) -> bool:
        """Overwrite ``self.state`` from the highest-priority checkpoint available.

        Priority order: the local requeue checkpoint, then
        ``cfg.requeue_ckpt_path``, then the latest exported VALUE/MAIN checkpoints.

        Returns:
            Whether workers must be force-synced afterwards. This is False only
            when the state came from the local requeue checkpoint.
        """
        if REQUEUE_CKPT.exists():
            logging.info("Found requeue checkpoint: %s", REQUEUE_CKPT.resolve())
            self.state = TrainerState.load(REQUEUE_CKPT, self.state, self.device)
            return False

        if self.cfg.requeue_ckpt_path:
            logging.info("Using explicit requeue checkpoint: %s", self.cfg.requeue_ckpt_path)
            explicit_path = pathlib.Path(self.cfg.requeue_ckpt_path)
            assert explicit_path.exists(), explicit_path
            self.state = TrainerState.load(explicit_path, self.state, self.device)
            return True

        if self.state.value_net_state is not None:
            latest_value = self._latest_exported_checkpoint(CKPT_VALUE_DIR)
            if latest_value is None:
                logging.info("No VALUE checkpoint found")
            else:
                logging.info(
                    "Found existing VALUE checkpoint folder. Will load last one: %s",
                    latest_value,
                )
                self.state.value_net_state = NetTrainingState.load(
                    latest_value, self.state.value_net_state, self.device
                )
        if self.state.net_state is not None:
            latest_main = self._latest_exported_checkpoint(CKPT_MAIN_DIR)
            if latest_main is None:
                logging.info("No MAIN checkpoint found")
            else:
                logging.info(
                    "Found existing MAIN checkpoint folder. Will load last one: %s", latest_main
                )
                self.state.net_state = NetTrainingState.load(
                    latest_main, self.state.net_state, self.device
                )
        return True

    @staticmethod
    def _latest_exported_checkpoint(ckpt_dir: pathlib.Path) -> Optional[pathlib.Path]:
        """Return the lexicographically last entry of ``ckpt_dir``, or None if missing/empty."""
        if not ckpt_dir.exists():
            return None
        entries = list(ckpt_dir.iterdir())
        return max(entries, key=str) if entries else None

    def _join_ddp_process_group(self):
        """Initialize the NCCL process group, coordinated through a shared init file."""
        torch_ddp_init_fname = get_torch_ddp_init_fname()
        logging.info(f"Using {torch_ddp_init_fname} to coordinate launch of ddp processes.")
        logging.info("Waiting for distributed data parallel helpers to sync up")
        torch.distributed.init_process_group(
            "nccl",
            init_method=f"file://{torch_ddp_init_fname}",
            rank=self.ectx.training_ddp_rank,
            world_size=self.ectx.ddp_world_size,
        )
        logging.info("Distributed data parallel helpers synced, proceeding with training")

    def _init_metrics_logger(self):
        """Create ``self.logger`` (and, on the training master, the logging server).

        Non-master processes request a dummy remote logger. The config is logged
        through ``self.logger`` in every process.
        """
        remote_metric_logger = fairdiplomacy.selfplay.remote_metric_logger
        if self.ectx.is_training_master:
            run_kind = "alphadip" if self.research else "rl_pg"
            self.logger_server = remote_metric_logger.MetricLoggingServer(self.cfg, run_kind)
        # DDP training helpers don't log anything. Can hack this here if needed for debugging them.
        use_dummy = not self.ectx.is_training_master
        self.logger = remote_metric_logger.get_remote_logger(is_dummy=use_dummy)
        self.logger.log_config(self.cfg)

    def terminate(self):
        """Shut down the data loader, logger and logger server, in that order.

        Components never attached to ``self`` (e.g. when initialization stopped
        early) are skipped.
        """
        shutdown_steps = (
            ("data_loader", "Killing data loader", "terminate"),
            ("logger", "Closing the logger", "close"),
            ("logger_server", "Closing the logger server", "terminate"),
        )
        for attr_name, message, method_name in shutdown_steps:
            if not hasattr(self, attr_name):
                continue
            logging.info(message)
            getattr(getattr(self, attr_name), method_name)()

    def on_requeue(self):
        """Requeue hook: log the event, then release resources via ``terminate``."""
        logging.info("Pre-termination callback")
        self.terminate()

    def maybe_send_model_to_workers(self, force_sync=False):
        """Push current weights to the data generation workers when a sync is due.

        The main net is synced when ``global_step`` is a multiple of
        ``trainer.save_sync_checkpoint_every``. A separate value net (if any) is
        synced when ``global_step`` is a multiple of
        ``trainer.save_sync_value_checkpoint_every``, falling back to the main
        period when that is unset/zero. ``force_sync=True`` syncs unconditionally.
        DDP training helpers never sync.
        """
        if self.ectx.is_training_helper:
            return
        trainer_cfg = self.cfg.trainer
        sync_main = force_sync or (
            self.state.global_step % trainer_cfg.save_sync_checkpoint_every == 0
        )
        sync_value = force_sync or (
            self.state.global_step
            % (
                trainer_cfg.save_sync_value_checkpoint_every
                or trainer_cfg.save_sync_checkpoint_every
            )
            == 0
        )

        if not self.research:
            if sync_main:
                self.data_loader.update_model(
                    self.state.net_state.model,
                    # NOTE(review): this path reads net_state.global_step, while the
                    # ReSearch path reads self.state.global_step; confirm intended.
                    global_step=self.state.net_state.global_step,
                    args=self.state.net_state.args,
                    epoch=self.state.epoch_id,
                )
            # Policy-gradient training never has a separate value net.
            assert self.state.value_net_state is None
            return

        if sync_main:
            self.data_loader.update_model(
                self.state.net_state.model,
                global_step=self.state.global_step,
                args=self.state.net_state.args,
                epoch=self.state.epoch_id,
                as_policy=True,
                # Without a separate value net, the main net is also sent as the value model.
                as_value=self.state.value_net_state is None,
            )
        if sync_value and self.state.value_net_state is not None:
            self.data_loader.update_model(
                self.state.value_net_state.model,
                global_step=self.state.global_step,
                args=self.state.value_net_state.args,
                epoch=self.state.epoch_id,
                as_policy=False,
                as_value=True,
            )

    def __call__(self):
        """Entry point: set up data, state and logging, then train or benchmark.

        Data generation worker processes never return from data loader
        initialization. Only the training master and DDP training helpers run
        the rest of this method.
        """
        # we call this immediately to send the data workers off to the right places
        # After this returns, we're in the master process or training helper processes only.
        self._init_data_loader_workers_may_never_return()
        self._init_training_device_and_state()
        self._init_metrics_logger()

        if self.ectx.is_training_master:
            CKPT_MAIN_DIR.mkdir(exist_ok=True, parents=True)
            if self.state.value_net_state is not None:
                CKPT_VALUE_DIR.mkdir(exist_ok=True, parents=True)

        self._configure_train_eval_modes()
        self._apply_training_permute_powers()
        self._place_models_for_training()

        if torch.cuda.is_available():
            pynvml.nvmlInit()

        if self.research and self.cfg.search_rollout.benchmark_only:
            self.perform_benchmark()
            return

        # Installing requeue handle at the last moment so that we don't requeue zombie job.
        heyhi.maybe_init_requeue_handler(self.on_requeue)

        self.run_training_loop()

    def _configure_train_eval_modes(self):
        """Put models in train mode, then apply at most one configured eval override."""
        self.state.model.train()
        if self.state.value_net_state is not None:
            self.state.value_net_state.model.train()

        trainer_cfg = self.cfg.trainer
        if trainer_cfg.train_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.eval()
            # Cast cuDNN RNN back to train mode.
            self.state.model.apply(_lstm_to_train)
        elif trainer_cfg.train_encoder_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.encoder.eval()
        elif trainer_cfg.train_decoder_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.policy_decoder.eval()
            self.state.model.policy_decoder.apply(_lstm_to_train)
        elif trainer_cfg.train_as_eval_but_batchnorm:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.eval()
            self.state.model.apply(_lstm_to_train)
            self.state.model.apply(_bn_to_train)

    def _apply_training_permute_powers(self):
        """Propagate ``cfg.training_permute_powers`` to the main and value models."""
        permute_powers = self.cfg.training_permute_powers
        logging.info(f"Enforcing self.cfg.training_permute_powers = {permute_powers}")
        self.state.net_state.model.set_training_permute_powers(permute_powers)
        if self.state.value_net_state is not None:
            self.state.value_net_state.model.set_training_permute_powers(permute_powers)

    def _place_models_for_training(self):
        """Move models to the training device, or wrap them for multi-GPU training."""
        if not self.research:
            assert self.cfg.num_train_gpus == 1, "Only one training GPU for policy gradients"
            self.state.model.to(self.device)
            return

        net_states = [self.state.net_state]
        if self.state.value_net_state is not None:
            net_states.append(self.state.value_net_state)

        if self.cfg.num_train_gpus <= 1:
            for net_state in net_states:
                net_state.model.to(self.device)
            return

        assert torch.cuda.device_count() == 8, "Can only go multi-gpu on full machine"
        for net_state in net_states:
            if self.cfg.use_distributed_data_parallel:
                net_state.model = torch.nn.parallel.DistributedDataParallel(
                    net_state.model,
                    device_ids=(self.ectx.training_ddp_rank,),
                    output_device=self.ectx.training_ddp_rank,
                )
            else:
                net_state.model = torch.nn.DataParallel(
                    net_state.model, device_ids=tuple(range(self.cfg.num_train_gpus)),
                )

    def run_training_loop(self):
        """Run training epochs from ``self.state.epoch_id`` up to ``trainer.max_epochs``.

        When ``trainer.max_epochs`` is unset/zero, the limit is 10**9 epochs.
        Each epoch:

        1. The training master saves ``REQUEUE_CKPT``. Every
           ``trainer.save_checkpoint_every`` epochs it also exports the main net
           checkpoint, plus the value net checkpoint when one exists.
        2. ``trainer.epoch_size`` calls to ``do_step`` accumulate counters.
        3. Epoch scalars are logged to the console and the metric logger: step
           metrics, throughput, GPU/host memory and eval scores.
        4. The LR schedulers, if any, are stepped once.
        """
        logging.info("Beginning training loop")
        max_epochs = self.cfg.trainer.max_epochs or 10 ** 9
        for self.state.epoch_id in range(self.state.epoch_id, max_epochs):
            # Clone state each epoch in case we'll need to requeue.
            if self.ectx.is_training_master:
                self.state.save(REQUEUE_CKPT)
                if (
                    self.cfg.trainer.save_checkpoint_every
                    and self.state.epoch_id % self.cfg.trainer.save_checkpoint_every == 0
                ):
                    self.state.net_state.save(CKPT_MAIN_DIR / (CKPT_TPL % self.state.epoch_id))
                    if self.state.value_net_state is not None:
                        self.state.value_net_state.save(
                            CKPT_VALUE_DIR / (CKPT_TPL % self.state.epoch_id)
                        )

            # Counters accumulate statistics over the epoch. The default
            # accumulation strategy is averaging.
            counters = collections.defaultdict(fairdiplomacy.selfplay.metrics.FractionCounter)
            use_grad_clip = (self.cfg.optimizer.grad_clip or 0) > 1e-10
            if use_grad_clip:
                counters["optim/grad_max"] = fairdiplomacy.selfplay.metrics.MaxCounter()
            if not self.research:
                counters["score/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
                for p in POWERS:
                    counters[f"score_{p}/num_games"] = fairdiplomacy.selfplay.metrics.SumCounter()
            # For LR just record its value at the start of the epoch.
            counters["optim/lr"].update(next(iter(self.state.optimizer.param_groups))["lr"])
            if self.state.value_net_state is not None:
                counters["optim/lr_value"].update(
                    next(iter(self.state.value_net_state.optimizer.param_groups))["lr"]
                )
            epoch_start_time = time.time()
            for _ in range(self.cfg.trainer.epoch_size):
                self.do_step(counters=counters, use_grad_clip=use_grad_clip)
                # Log on every one of the first 128 steps. After that, log only when
                # global_step + 1 is a power of two.
                if (
                    self.state.global_step < 128
                    or (self.state.global_step & (self.state.global_step + 1)) == 0
                ):
                    logging.info(
                        "Metrics (global_step=%d): %s",
                        self.state.global_step,
                        {k: v.value() for k, v in sorted(counters.items())},
                    )
                self.state.global_step += 1
            if self.research:
                for key, value in self.data_loader.get_buffer_stats(prefix="buffer/").items():
                    counters[key].update(value)

            epoch_scalars = {k: v.value() for k, v in sorted(counters.items())}
            average_batch_size = epoch_scalars["size/batch"]
            epoch_scalars["speed/loop_bps"] = self.cfg.trainer.epoch_size / (
                time.time() - epoch_start_time + 1e-5
            )
            epoch_scalars["speed/loop_eps"] = epoch_scalars["speed/loop_bps"] * average_batch_size
            # Speed for to_cuda + forward + backward.
            torch_time = epoch_scalars["time/net"] + epoch_scalars["time/to_cuda"]
            epoch_scalars["speed/train_bps"] = 1.0 / torch_time
            epoch_scalars["speed/train_eps"] = average_batch_size / torch_time

            if torch.cuda.is_available():
                for i in range(pynvml.nvmlDeviceGetCount()):
                    mem_info = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(i))
                    epoch_scalars[f"gpu_mem_used/{i}"] = mem_info.used / 2 ** 30
                    epoch_scalars[f"gpu_mem_free/{i}"] = mem_info.free / 2 ** 30

            mem_stats = psutil.virtual_memory()
            epoch_scalars["memory/used_gb"] = mem_stats.used / 2 ** 30
            epoch_scalars["memory/available_gb"] = mem_stats.available / 2 ** 30
            epoch_scalars["memory/free_gb"] = mem_stats.free / 2 ** 30

            eval_scores = self.data_loader.extract_eval_scores()
            if eval_scores is not None:
                for k, v in eval_scores.items():
                    epoch_scalars[f"score_eval/{k}"] = v

            logging.info(
                "Finished epoch %d. Metrics:\n%s",
                self.state.epoch_id,
                format_metrics_for_print(epoch_scalars),
            )
            self.logger.log_metrics(epoch_scalars, self.state.epoch_id)

            # Schedulers advance once per epoch.
            if self.state.scheduler is not None:
                self.state.scheduler.step()
            if self.state.value_net_state is not None and self.state.value_net_state.scheduler:
                self.state.value_net_state.scheduler.step()

        logging.info("End of training")
        logging.info("Exiting main function")

    def perform_benchmark(self):
        """Benchmark data generation only: fetch 1000 batches and log throughput.

        No training happens. ``self.state.epoch_id`` is set to the batch index.
        After every batch, the cumulative counters are logged:
          * ``speed/rps``: fed (sum of ``batch.done``, seconds waited);
          * ``speed/phps``: fed (number of ``batch.done`` entries, seconds waited);
          * ``rollout_len``: fed (number of ``batch.done`` entries, sum of ``batch.done``).
        GPU memory used/free in GiB is added when CUDA is available.
        """
        tick = time.time()
        counters = collections.defaultdict(fairdiplomacy.selfplay.metrics.FractionCounter)
        for self.state.epoch_id in range(1000):
            batch: ReSearchRolloutBatch = self.data_loader.get_batch()  # type: ignore
            elapsed = time.time() - tick
            tick = time.time()
            counters["speed/rps"].update(batch.done.sum(), elapsed)
            counters["speed/phps"].update(batch.done.numel(), elapsed)
            counters["rollout_len"].update(batch.done.numel(), batch.done.sum())

            epoch_scalars = {name: counter.value() for name, counter in sorted(counters.items())}
            if torch.cuda.is_available():
                for gpu_index in range(pynvml.nvmlDeviceGetCount()):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
                    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    epoch_scalars[f"gpu_mem_used/{gpu_index}"] = mem_info.used / 2 ** 30
                    epoch_scalars[f"gpu_mem_free/{gpu_index}"] = mem_info.free / 2 ** 30

            logging.info(
                "Finished epoch %d. Metrics:\n%s",
                self.state.epoch_id,
                format_metrics_for_print(epoch_scalars),
            )
            self.logger.log_metrics(epoch_scalars, self.state.epoch_id)

    def do_step(self, **kwargs):
        """Run one training step using the implementation matching ``self.research``.

        All keyword arguments are forwarded unchanged to either
        ``do_step_research`` (when ``self.research`` is truthy) or
        ``do_step_policy_gradient``. That method's return value is returned.
        """
        step_impl = self.do_step_research if self.research else self.do_step_policy_gradient
        return step_impl(**kwargs)

    def do_step_research(self, *, counters: collections.defaultdict, use_grad_clip: bool):
        """Run one ReSearch training step: a value and/or search-policy update.

        Fetches a ReSearchRolloutBatch from ``self.data_loader`` and computes two
        optional losses:
          * the critic loss: MSE between ``targets`` and the value model's predictions;
          * the search-policy cross-entropy loss.
        It then takes an optimizer step on every network that received
        gradients and records metrics in ``counters``, which is updated in place.

        Each loss term is included only for a network that is being trained,
        and only with probability ``cfg.value_update_prob`` /
        ``cfg.search_policy_update_prob``. If every position in the batch is an
        exploration move, the step returns early without updating anything.

        Args:
            counters: Mapping from metric name to a counter with an ``update``
                method; missing keys are created on access.
            use_grad_clip: If True, clip gradient norms to the configured
                ``grad_clip`` and record gradient-norm statistics.
        """
        device = self.device
        timings = TimingCtx()
        with timings("data_gen"):
            research_batch: ReSearchRolloutBatch = self.data_loader.get_batch()  # type: ignore

        # At least one of the two losses must be computed on every step it is enabled for.
        assert (
            self.cfg.search_policy_update_prob >= 1.0 or self.cfg.value_update_prob >= 1.0
        ), "It does not make sense to subsample both"
        do_search_policy_loss = (
            self.is_policy_being_trained
            and self.random.random() <= self.cfg.search_policy_update_prob
        )
        # Plain bool. A set literal `{...}` here would be non-empty and so
        # always truthy, which would silently disable value-loss subsampling.
        do_value_loss = (
            self.is_value_being_trained and self.random.random() <= self.cfg.value_update_prob
        )

        main_net_has_grads = value_net_has_grads = False
        per_dataloader_batch_size = self.cfg.search_rollout.batch_size // self.ectx.ddp_world_size
        with timings("to_cuda"):
            rewards = research_batch.rewards.to(device)
            if self.cfg.search_rollout.buffer.shuffle:
                assert list(rewards.shape) == [
                    1,
                    self.cfg.search_rollout.chunk_length * per_dataloader_batch_size,
                    len(POWERS),
                ], rewards.shape
            else:
                assert list(rewards.shape) == [
                    self.cfg.search_rollout.chunk_length,
                    per_dataloader_batch_size,
                    len(POWERS),
                ], rewards.shape
            obs = {k: v.to(device) for k, v in research_batch.observations.items()}
            done = research_batch.done.to(device)
            is_search_policy_valid = research_batch.is_search_policy_valid.to(device)
            is_explore = research_batch.is_explore.to(device)
            is_dead = (research_batch.scores < 1e-3).float().to(device)
            targets = research_batch.targets.to(device)
            is_move_phase = (research_batch.phase_type == ord("M")).to(device)
            is_adj_phase = (research_batch.phase_type == ord("A")).to(device)
            years = research_batch.years.to(device)
            if is_explore.all():
                logging.warning("Whole batch of explore!!! Skipping")
                return
            if do_search_policy_loss:
                search_policy_probs, search_policy_orders, blueprint_probs = (
                    research_batch.search_policy_probs.to(device),
                    research_batch.search_policy_orders.long().to(device),
                    research_batch.blueprint_probs.to(device),
                )
                if self.cfg.search_ev_loss is not None:
                    search_policy_evs = research_batch.search_policy_evs.to(device)
            else:
                search_policy_probs = search_policy_orders = None

        is_all_powers = (
            self.state.model.module if hasattr(self.state.model, "module") else self.state.model
        ).is_all_powers()

        with timings("net"):
            loss = torch.tensor(0.0, device=device)
            losses = {}
            if do_value_loss:
                if self.state.value_net_state is not None:
                    value_net_has_grads = True
                else:
                    main_net_has_grads = True

                flat_obs = nest.map(lambda x: x.flatten(end_dim=1), obs)
                # Shape: [T, B, 7].
                _, _, _, predicted_values = self.state.value_model(
                    **flat_obs, temperature=1.0, need_policy=False,
                )
                predicted_values = predicted_values.reshape(rewards.shape)

                # Note, if you ever change this, you have to propagate discounting
                # to search_data_loader akin to data_loader.
                assert self.cfg.discounting == 1.0, "Discounting is not supported for ReSearch"

                critic_mses = torch.nn.functional.mse_loss(
                    targets, predicted_values, reduction="none"
                )
                losses["critic"] = critic_mses.mean()
                loss = loss + self.cfg.critic_weight * losses["critic"]

            if do_search_policy_loss:
                main_net_has_grads = True
                if self.cfg.search_ev_loss is not None:
                    policy_loss_targets = evs_to_policy(
                        search_policy_evs,
                        temperature=self.cfg.search_ev_loss.temperature,
                        use_softmax=self.cfg.search_ev_loss.use_softmax,
                    )
                else:
                    policy_loss_targets = search_policy_probs

                (
                    search_policy_loss,
                    search_policy_metrics,
                ) = compute_search_policy_cross_entropy_sampled(
                    self.state.model,
                    obs,
                    search_policy_orders,
                    policy_loss_targets,
                    blueprint_probs,
                    mask=is_search_policy_valid,
                    is_move_phase=is_move_phase,
                    is_adj_phase=is_adj_phase,
                    using_ddp=self.cfg.use_distributed_data_parallel,
                    max_prob_cap=self.cfg.search_policy_max_prob_cap,
                    is_all_powers=is_all_powers,
                    power_conditioning=self.cfg.power_conditioning,
                    single_power_chances=self.cfg.single_power_chances,
                    six_power_chances=self.cfg.six_power_chances,
                )
                for k, v in search_policy_metrics.items():
                    counters[k].update(v)

                losses["search_policy"] = search_policy_loss
                loss = loss + search_policy_loss * self.cfg.search_policy_weight
                if (
                    self.cfg.sampled_entropy_weight is not None
                    and self.cfg.sampled_entropy_weight > 0.0
                ):
                    e_loss, e_mean = sample_and_compute_sampled_entropy_loss(
                        self.state.model, obs,
                    )
                    losses["policy_entropy_loss"] = e_loss
                    losses["policy_entropy"] = e_mean
                    loss = loss + e_loss * self.cfg.sampled_entropy_weight
            else:
                search_policy_loss = None

            self.state.optimizer.zero_grad()
            if self.state.value_net_state is not None:
                self.state.value_net_state.optimizer.zero_grad()

            # If neither loss term was computed, `loss` is a constant with no
            # autograd graph and backward() would raise. There is nothing to update.
            if main_net_has_grads or value_net_has_grads:
                loss.backward()

            g_norm_tensor = g_norm_value_tensor = None
            if use_grad_clip:
                if main_net_has_grads:
                    g_norm_tensor = clip_grad_norm_(
                        self.state.model.parameters(), self.cfg.optimizer.grad_clip
                    )
                if value_net_has_grads:
                    value_grad_clip = (self.cfg.value_optimizer or self.cfg.optimizer).grad_clip
                    g_norm_value_tensor = clip_grad_norm_(
                        self.state.value_net_state.model.parameters(), value_grad_clip
                    )

            # Once cfg.trainer.max_updates (if set) is reached, keep computing
            # losses and metrics but stop changing the weights.
            if (
                not self.cfg.trainer.max_updates
                or self.state.global_step < self.cfg.trainer.max_updates
            ):
                if main_net_has_grads:
                    self.state.net_state.optimizer.step()
                if value_net_has_grads:
                    self.state.value_net_state.optimizer.step()
            # Sync to make sure timing is correct.
            loss.item()

        with timings("metrics"), torch.no_grad():
            last_count = done.long().sum()

            time_bsz = rewards.shape[0] * rewards.shape[1]
            if do_value_loss:
                critic_end_mses = critic_mses[done].sum()

            if use_grad_clip:
                if g_norm_tensor is not None:
                    g_norm = g_norm_tensor.item()
                    counters["optim/grad_max"].update(g_norm)
                    counters["optim/grad_mean"].update(g_norm)
                    counters["optim/grad_clip_ratio"].update(
                        int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
                    )
                if g_norm_value_tensor is not None:
                    g_norm_value = g_norm_value_tensor.item()
                    counters["optim/grad_value_max"].update(g_norm_value)
                    counters["optim/grad_value_mean"].update(g_norm_value)
                    counters["optim/grad_value_clip_ratio"].update(
                        int(g_norm_value >= value_grad_clip - 1e-5)
                    )
            for key, value in losses.items():
                counters[f"loss/{key}"].update(value)
            counters["loss/total"].update(loss.item())

            explored_on_the_right = research_batch.explored_on_the_right
            if do_value_loss:
                counters["loss/critic_no_explore"].update(
                    critic_mses[~explored_on_the_right].sum(), explored_on_the_right.long().sum()
                )
                counters["loss/critic_last"].update(critic_end_mses, last_count)
            counters["loss/is_explore"].update(is_explore.long().sum(), (1 - is_dead).sum())
            counters["loss/offpolicy_part"].update(
                explored_on_the_right.long().sum(), (1 - is_dead).sum()
            )
            counters["loss/is_search_policy_valid"].update(
                is_search_policy_valid.float().sum(), is_search_policy_valid.numel()
            )

            if do_search_policy_loss and not is_all_powers:
                # What's the entropy of the search policy.
                counters["loss/entropy_search"].update(
                    compute_search_policy_entropy(
                        search_policy_orders, search_policy_probs, mask=is_search_policy_valid
                    )
                )
                # is_search_policy_valid is indexed by [T, B, power]
                # is_move_phase is indexed [T, B]
                counters["loss/entropy_search_moves"].update(
                    compute_search_policy_entropy(
                        search_policy_orders,
                        search_policy_probs,
                        mask=is_search_policy_valid * is_move_phase.unsqueeze(2),
                    )
                )
                first_two_phases_move_mask = (
                    is_search_policy_valid
                    * is_move_phase.unsqueeze(2)
                    * (years == 1901).unsqueeze(2)
                )
                if first_two_phases_move_mask.any():
                    counters["loss/entropy_search_moves_1901"].update(
                        compute_search_policy_entropy(
                            search_policy_orders,
                            search_policy_probs,
                            mask=first_two_phases_move_mask,
                        )
                    )
                if self.cfg.search_ev_loss is not None:
                    counters["loss/entropy_search_from_evs"].update(
                        compute_search_policy_entropy(
                            search_policy_orders, policy_loss_targets, mask=is_search_policy_valid
                        )
                    )
                phase_bp_sums = blueprint_probs.flatten(end_dim=2)[
                    is_search_policy_valid.flatten(end_dim=2)
                ]
                phase_bp_sums = phase_bp_sums.sum(-1).view(-1)
                phase_bp_sums = phase_bp_sums[phase_bp_sums > 1e-10]
                counters["loss/bp_share"].update(phase_bp_sums.sum(), len(phase_bp_sums))

            counters["reward/mean"].update(rewards.sum(), time_bsz)
            # Rewards at the end of episodes. We precompute everything
            # before adding to counters to pipeline things when
            # possible.
            last_rewards = rewards[done]
            last_sum = last_rewards.sum()
            # predicted_values only exists when the value loss was computed
            # on this step (it may be skipped by value_update_prob).
            if do_value_loss:
                # Mean predicted value for dead powers.
                counters["value/mean_dead"].update(
                    (predicted_values * is_dead).sum(), is_dead.sum()
                )
            counters["reward/last"].update(last_sum, last_count)
            counters["reward_solo/last"].update((last_rewards > 0.9).float().sum(), last_count)
            for i, power in enumerate(POWERS):
                power_rewards = last_rewards[..., i]
                counters[f"reward/last_{power}"].update(power_rewards.sum(), last_count)
                counters[f"reward_solo/last_{power}"].update(
                    (power_rewards > 0.9).sum(), last_count
                )

            # Average batch size in phases, i.e., T * B.
            counters["size/batch"].update(time_bsz * self.ectx.ddp_world_size)
            if self.ectx.ddp_world_size > 1:
                counters["size/batch_local"].update(time_bsz)
            # Average number of phases per episode.
            counters["size/episode"].update(time_bsz, last_count)
            # Average number of move phases per episode.
            counters["size/episode_moves"].update(is_move_phase.float().sum(), last_count)
            # Average last year.
            counters["size/end_year"].update((years[done].float() - 1900).sum(), last_count)
            if search_policy_orders is not None:
                counters["size/search_policy_samples"].update(
                    (search_policy_orders != EOS_IDX).any(-1).any(2).float().sum(2).mean()
                )
                # search_policy_orders: [T, B, seven, max_actions, MAX_SEQ_LEN].
                # is_move_phase: [T, B].
                counters["size/search_policy_samples_moves"].update(
                    (search_policy_orders.flatten(end_dim=1)[is_move_phase.view(-1)] != EOS_IDX)
                    .any(-1)
                    .any(1)
                    .float()
                    .sum(-1)
                    .mean()
                )

        with timings("sync"), torch.no_grad():
            self.maybe_send_model_to_workers()

        # Doing outside of the context to capture the context's timing.
        for key, value in timings.items():
            counters[f"time/{key}"].update(value)

    def do_step_policy_gradient(self, *, counters: collections.defaultdict, use_grad_clip: bool):
        """Run one V-trace policy-gradient (actor-critic) update.

        Steps:
          * Fetch a batch of transitions from ``self.data_loader``, with the
            acting power id per row, plus rollout scores.
          * Run ``self.state.model`` with teacher-forced orders.
          * Compute actor, critic and entropy losses using V-trace returns.
          * Step the optimizer, unless ``cfg.trainer.max_updates`` is set and
            has been reached.
          * Record metrics in ``counters``, which is updated in place.

        Args:
            counters: Mapping from metric name to a counter with an ``update``
                method; missing keys are created on access.
            use_grad_clip: If True, clip the gradient norm to
                ``cfg.optimizer.grad_clip`` and record gradient-norm statistics.
        """
        device = self.device
        timings = TimingCtx()
        with timings("data_gen"):
            (
                (power_ids, obs, rewards, actions, behavior_action_logprobs, done),
                rollout_scores_per_power,
            ) = self.data_loader.get_batch()

        with timings("to_cuda"):
            actions = actions.to(device)
            rewards = rewards.to(device)
            power_ids = power_ids.to(device)
            obs = {k: v.to(device) for k, v in obs.items()}
            # Candidate indices are removed from the model inputs and used
            # separately below to score the taken actions.
            cand_actions = obs.pop("cand_indices")
            behavior_action_logprobs = behavior_action_logprobs.to(device)
            done = done.to(device)

        with timings("net"):
            # Shape: _, [B, 17], [B, S, 469], [B, 7].
            # policy_cand_actions has the same information as actions,
            # but uses local indices to match policy logits.
            assert EOS_IDX == -1, "Rewrite the code to remove the assumption"
            _, _, policy_logits, sc_values = self.state.model(
                **obs,
                temperature=1.0,
                teacher_force_orders=actions.clamp(0),  # EOS_IDX = -1 -> 0
                x_power=power_ids.view(-1, 1).repeat(1, MAX_SEQ_LEN),
            )
            # Trim candidates to policy_logits' second dim (S in the shape comment above).
            cand_actions = cand_actions[:, : policy_logits.shape[1]]

            # Shape: [B]. Keep only the value of the power that acted in each row.
            sc_values = sc_values.gather(1, power_ids.unsqueeze(1)).squeeze(1)

            # Removing absolute order ids to not use them by accident.
            # Will use relative order ids (cand_actions) from now on.
            del actions

            if self.cfg.rollout.do_not_split_rollouts:
                # Assumes that the episode actually ends, so the bootstrap value is zero.
                bootstrap_value = torch.zeros_like(sc_values[-1])
            else:
                # The last row is used only to bootstrap the value, which reduces
                # the batch size by one. Tensors not worth trimming are deleted
                # so they cannot be used by accident.
                bootstrap_value = sc_values[-1].detach()
                sc_values = sc_values[:-1]
                cand_actions = cand_actions[:-1]
                policy_logits = policy_logits[:-1]
                rewards = rewards[:-1]
                power_ids = power_ids[:-1]
                del obs
                behavior_action_logprobs = behavior_action_logprobs[:-1]
                done = done[:-1]

            # Shape: [B]. Zero discount at episode ends.
            discounts = (~done).float() * self.cfg.discounting

            # Shape: [B, 17].
            mask = (cand_actions != EOS_IDX).float()

            # Shape: [B].
            policy_action_logprobs = order_logits_to_action_logprobs(
                policy_logits, cand_actions, mask
            )

            # log_rhos: log importance ratio of the current policy vs the
            # behavior policy that generated the data.
            vtrace_returns = vtrace_from_logprobs_no_batch(
                log_rhos=policy_action_logprobs - behavior_action_logprobs,
                discounts=discounts,
                rewards=rewards,
                values=sc_values,
                bootstrap_value=bootstrap_value,
            )

            # Critic targets are detached so gradients flow only through sc_values.
            critic_mses = 0.5 * ((vtrace_returns.vs.detach() - sc_values) ** 2)

            losses = dict(
                actor=compute_policy_gradient_loss(
                    policy_action_logprobs, vtrace_returns.pg_advantages
                ),
                critic=critic_mses.mean(),
                entropy=compute_entropy_loss(policy_logits, mask),
            )

            loss = (
                losses["actor"]
                + self.cfg.critic_weight * losses["critic"]
                + self.cfg.entropy_weight * losses["entropy"]
            )
            if self.cfg.sampled_entropy_weight:
                loss = loss + self.cfg.sampled_entropy_weight * compute_sampled_entropy_loss(
                    policy_action_logprobs
                )

            self.state.optimizer.zero_grad()
            loss.backward()

            if use_grad_clip:
                g_norm_tensor = clip_grad_norm_(
                    self.state.model.parameters(), self.cfg.optimizer.grad_clip
                )

            # Once cfg.trainer.max_updates (if set) is reached, keep computing
            # losses and metrics but stop changing the weights.
            if (
                not self.cfg.trainer.max_updates
                or self.state.global_step < self.cfg.trainer.max_updates
            ):
                self.state.optimizer.step()
            # Sync to make sure timing is correct.
            loss.item()

        with timings("metrics"), torch.no_grad():
            last_count = done.long().sum()
            critic_end_mses = critic_mses[done].sum()

            if use_grad_clip:
                g_norm = g_norm_tensor.item()
                counters["optim/grad_max"].update(g_norm)
                counters["optim/grad_mean"].update(g_norm)
                counters["optim/grad_clip_ratio"].update(
                    int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
                )
            for key, value in losses.items():
                counters[f"loss/{key}"].update(value)
            counters["loss/total"].update(loss.item())
            # A power_id of None is logged under the plain "score" prefix.
            for power_id, rollout_scores in rollout_scores_per_power.items():
                prefix = f"score_{POWERS[power_id]}" if power_id is not None else "score"
                for key, value in rollout_scores.items():
                    if key != "num_games":
                        counters[f"{prefix}/{key}"].update(value, rollout_scores["num_games"])
                    else:
                        counters[f"{prefix}/{key}"].update(value)

            counters["loss/critic_last"].update(critic_end_mses, last_count)

            counters["reward/mean"].update(rewards.sum(), len(rewards))
            # Rewards at the end of episodes. We precompute everything
            # before adding to counters to pipeline things when
            # possible.
            last_rewards = rewards[done]
            last_sum = last_rewards.sum()
            # tensor [num_powers, num_dones].
            last_power_masks = (
                power_ids[done].unsqueeze(0)
                == torch.arange(len(POWERS), device=power_ids.device).unsqueeze(1)
            ).float()
            last_power_rewards = (last_power_masks * last_rewards.unsqueeze(0)).sum(1)
            last_power_counts = last_power_masks.sum(1)
            counters["reward/last"].update(last_sum, last_count)
            for power, reward, counts in zip(
                POWERS, last_power_rewards.cpu(), last_power_counts.cpu()
            ):
                counters[f"reward/last_{power}"].update(reward, counts)
            # To match entropy loss we don't negate logprobs. So this
            # is an estimate of the negative entropy.
            counters["loss/entropy_sampled"].update(policy_action_logprobs.mean())

            # Measure off-policiness.
            counters["loss/rho"].update(vtrace_returns.rhos.sum(), vtrace_returns.rhos.numel())
            counters["loss/rhos_clipped"].update(
                vtrace_returns.clipped_rhos.sum(), vtrace_returns.clipped_rhos.numel()
            )

            bsz = len(rewards)
            counters["size/batch"].update(bsz)
            counters["size/episode"].update(bsz, last_count)

        with timings("sync"), torch.no_grad():
            self.maybe_send_model_to_workers()

        # Doing outside of the context to capture the context's timing.
        for key, value in timings.items():
            counters[f"time/{key}"].update(value)


def format_metrics_for_print(metrics: Dict[str, float]) -> str:
    """Render metrics as an indented, human-readable block grouped by key prefix.

    Keys are grouped by the text before their first ``/``. Each group emits:
      * a header line ``"  <prefix>/"``;
      * one line of space-separated ``name=value`` pairs, where ``name`` is the
        key with the prefix removed and ``value`` is formatted with ``%g``.

    A key with no ``/`` belongs to the group named after itself and is printed
    under its full name. Groups appear in the order of their first key in
    sorted order, and entries within a group are sorted by key.
    """
    # Group explicitly instead of using itertools.groupby over a plain key
    # sort. In that sort a slash-less key like "a" and "a/y" can be separated
    # by "a.b/x", which would make groupby emit the "a" group twice.
    groups: Dict[str, list] = {}
    for key in sorted(metrics):
        groups.setdefault(key.split("/", 1)[0], []).append(key)

    lines = []
    for prefix, keys in groups.items():
        lines.append(f"  {prefix}/\n")
        pairs = " ".join("%s=%g" % (key.split("/", 1)[-1], metrics[key]) for key in keys)
        lines.append(f"    {pairs}\n")
    return "".join(lines)


def training_helper(cfg, ectx, random_seed):
    """Process target that runs one ExploitTrainer for an extra DDP training rank.

    Configures logging with a ``train<rank>`` label, builds the trainer and
    runs it. Once the trainer is built, ``terminate()`` is always called, even
    if running it raises.
    """
    rank_label = "train%d" % ectx.training_ddp_rank
    heyhi.setup_logging(label=rank_label)
    helper_trainer = ExploitTrainer(cfg=cfg, ectx=ectx, random_seed=random_seed)
    try:
        helper_trainer()
    finally:
        helper_trainer.terminate()


def task(cfg):
    """Entry point for the training task on every machine.

    On the master machine:
      * recreate the rendezvous directory and set logging to the ``master`` label;
      * if ``cfg.use_distributed_data_parallel`` is set, spawn one
        ``training_helper`` process for each rank 1..world_size-1;
      * run the rank-0 trainer in this process.

    On other machines: sleep 10 seconds to let the master set up the
    rendezvous, then run a trainer with ``training_ddp_rank=None`` (logging
    label ``datagen``).

    Every process shares one random seed. Once the trainer is constructed, it
    is always terminated and any spawned helpers are killed when it exits.
    """
    random_seed = random.randrange(2 ** 63)
    helper_processes = []

    world_size = cfg.num_train_gpus if cfg.use_distributed_data_parallel else 1

    def make_ectx(training_ddp_rank: Optional[int]) -> ExecutionContext:
        return ExecutionContext(
            training_ddp_rank=training_ddp_rank,
            using_ddp=cfg.use_distributed_data_parallel,
            ddp_world_size=world_size,
        )

    if not heyhi.is_master():
        # Give the master time to recreate the rendezvous directory.
        logging.info("Delaying non-master machines by 10 seconds to setup rendezvous")
        time.sleep(10)
        heyhi.setup_logging(label="datagen")
        trainer = ExploitTrainer(cfg=cfg, ectx=make_ectx(None), random_seed=random_seed)
    else:
        rendezvous_dir = get_rendezvous_path()
        if rendezvous_dir.exists():
            logging.info("Removing old rendezvous dir")
            shutil.rmtree(rendezvous_dir)
        rendezvous_dir.mkdir(parents=True)
        heyhi.setup_logging(label="master")
        if cfg.use_distributed_data_parallel:
            # Rank 0 trains in this process; every other rank gets a helper process.
            for rank in range(1, world_size):
                logging.info(f"Spawning training helper {rank} for distributed data parallel")
                helper = ExceptionHandlingProcess(
                    target=training_helper,
                    args=[cfg, make_ectx(rank), random_seed],
                    daemon=True,
                )
                helper.start()
                helper_processes.append(helper)
            logging.info("Spawning complete, waiting for helpers to sync up")

        trainer = ExploitTrainer(cfg=cfg, ectx=make_ectx(0), random_seed=random_seed)

    try:
        trainer()
    finally:
        trainer.terminate()
        for helper in helper_processes:
            logging.info("Killing training helper")
            helper.kill()


def _lstm_to_train(module):
    if module.__class__.__name__ == "LSTM":
        module.train()


def _bn_to_train(module):
    if "BatchNorm" in module.__class__.__name__:
        module.train()


def _warmup_decay_schedule(warmup_decay: conf.conf_cfgs.ExploitTask.Optimization.WarmupDecay):
    """Build an epoch -> LR-multiplier function with linear warmup then linear decay.

    The returned ``schedule(epoch)`` behaves as follows:
      * For ``epoch <= warmup_epochs``: returns ``max(1, epoch) / warmup_epochs``,
        rising linearly to 1.0.
      * Otherwise: interpolates linearly from 1.0 down to ``final_decay``,
        reaching it at ``decay_epochs`` and staying there afterwards.

    Raises:
        AssertionError: if either epoch count is missing or the condition
            ``0 < warmup_epochs < decay_epochs`` does not hold.
    """
    assert warmup_decay.warmup_epochs is not None
    assert warmup_decay.decay_epochs is not None
    assert 0 < warmup_decay.warmup_epochs < warmup_decay.decay_epochs, warmup_decay

    def schedule(epoch):
        warmup_epochs = warmup_decay.warmup_epochs
        if epoch > warmup_epochs:
            decay_span = warmup_decay.decay_epochs - warmup_epochs
            progress = min((epoch - warmup_epochs) / decay_span, 1.0)
            # Linear interpolation between 1.0 (start of decay) and final_decay.
            return 1.0 * (1 - progress) + warmup_decay.final_decay * progress
        return max(1, epoch) / warmup_epochs

    return schedule
